#!/usr/bin/env python

import argparse
import os
import json

import numpy as np
import pandas as pd


def build_parser() -> argparse.ArgumentParser:
    """Return the command-line parser for the stack-health table builder.

    Every option is a filesystem path. All options have a default under
    ``outputs/`` except ``--prestack-meta``, which defaults to ``None``
    (meaning: no meta merge).
    """
    parser = argparse.ArgumentParser(
        description="Build a per-stack health table for T3 using existing outputs "
        "(no prestacking, no internal imports)."
    )

    # (flag, default, help) triples, registered in this exact order.
    option_specs = (
        (
            "--plateau-csv",
            "outputs/lensing_plateau.csv",
            "Path to lensing_plateau.csv (default: outputs/lensing_plateau.csv)",
        ),
        (
            "--windows-json",
            "outputs/windows.json",
            "Path to windows.json (default: outputs/windows.json)",
        ),
        (
            "--flatness-json",
            "outputs/flatness.json",
            "Path to flatness.json (default: outputs/flatness.json)",
        ),
        (
            "--prestack-meta",
            None,
            "Optional path to prestacked_meta.csv (stack-level meta: "
            "N_lens, N_src, geometry, etc.). If not provided, meta "
            "columns are omitted.",
        ),
        (
            "--out-csv",
            "outputs/stack_health.csv",
            "Output CSV path (default: outputs/stack_health.csv)",
        ),
    )
    for flag, default_value, help_text in option_specs:
        parser.add_argument(flag, default=default_value, help=help_text)
    return parser


def safe_parse_mid(label):
    """
    Parse a size-bin label into the numeric midpoint of the bin.

    Accepted inputs:
      * interval strings such as '[3,5)', '[3.0, 5.0)' or '(3, 5]'
        -> midpoint 0.5 * (lo + hi); the bracket style is ignored;
      * a bare number string such as '4.5' -> 4.5;
      * a real number (int/float, including numpy scalars, but not bool)
        -> that number as a float, so an already-numeric column still works.

    Returns NaN for anything else (None, empty string, malformed text,
    labels with more than two comma-separated parts, ...).
    """
    # Numeric labels: pandas may have parsed the column as numbers already.
    if isinstance(label, (int, float, np.integer, np.floating)) and not isinstance(
        label, bool
    ):
        return float(label)
    if not isinstance(label, str):
        return np.nan
    s = label.strip()
    if not s:
        return np.nan
    # Drop all interval bracket characters in a single pass.
    s = s.translate(str.maketrans("", "", "[]()"))
    parts = s.split(",")
    try:
        if len(parts) == 2:
            return 0.5 * (float(parts[0]) + float(parts[1]))
        # Not a two-endpoint interval: maybe it is just a number.
        return float(s)
    except ValueError:
        return np.nan


def load_json_or_empty(path):
    """
    Load a JSON object (mapping) from ``path``, best-effort.

    Returns the parsed dict on success. Prints a ``[warn]`` line and returns
    ``{}`` when the file is missing, is not valid UTF-8 JSON, or holds a JSON
    value that is not an object. Callers index the result with ``.get()``, so
    a non-dict must never be returned.
    """
    try:
        # JSON text is UTF-8 (RFC 8259); don't depend on the platform locale.
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except FileNotFoundError:
        print(f"[warn] JSON file not found: {path} (continuing without it)")
        return {}
    except (json.JSONDecodeError, UnicodeDecodeError):
        print(f"[warn] Could not parse JSON {path} (continuing without it)")
        return {}
    if not isinstance(data, dict):
        print(f"[warn] JSON {path} is not an object/mapping (continuing without it)")
        return {}
    return data


def _attach_windows(stacks, windows):
    """Add win_i0 / win_i1 / win_nbins columns to ``stacks`` in place from a windows mapping."""
    if not isinstance(windows, dict):
        windows = {}
    win_i0, win_i1, win_nbins = [], [], []
    for sid in stacks["stack_id_str"]:
        entry = windows.get(sid)
        if entry is not None and not isinstance(entry, dict):
            print(f"[warn] windows entry for stack {sid} is not an object; ignoring it.")
        if not isinstance(entry, dict):
            entry = {}
        i0 = entry.get("i0")
        i1 = entry.get("i1")
        # Missing bounds are stored as NaN (never None) so the columns stay numeric.
        win_i0.append(np.nan if i0 is None else i0)
        win_i1.append(np.nan if i1 is None else i1)
        try:
            # The +1 treats i1 as inclusive.
            # NOTE(review): confirm windows.json uses an inclusive i1.
            win_nbins.append(int(i1) - int(i0) + 1)
        except (TypeError, ValueError, OverflowError):
            win_nbins.append(0)
    stacks["win_i0"] = win_i0
    stacks["win_i1"] = win_i1
    stacks["win_nbins"] = win_nbins


def _attach_flatness(stacks, flatness_json):
    """Fill rmse_flat / R2_flat in place from flatness.json when the plateau CSV lacks them."""
    need_rmse = "rmse_flat" not in stacks.columns
    need_R2 = "R2_flat" not in stacks.columns
    if not (need_rmse or need_R2):
        # Columns already present in the plateau CSV are kept untouched.
        return
    flatness = load_json_or_empty(flatness_json)
    if not isinstance(flatness, dict):
        flatness = {}
    # A JSON null / non-object entry is treated like a missing stack.
    entries = [flatness.get(sid) for sid in stacks["stack_id_str"]]
    entries = [e if isinstance(e, dict) else {} for e in entries]
    if need_rmse:
        stacks["rmse_flat"] = [e.get("rmse_flat", np.nan) for e in entries]
    if need_R2:
        stacks["R2_flat"] = [e.get("R2_flat", np.nan) for e in entries]


def _attach_snr(stacks):
    """Add A_theta_err (CI half-width) and A_theta_SNR in place; NaN if inputs are absent."""
    needed = {"A_theta", "A_theta_CI_low", "A_theta_CI_high"}
    if needed.issubset(stacks.columns):
        err = 0.5 * (stacks["A_theta_CI_high"] - stacks["A_theta_CI_low"])
        stacks["A_theta_err"] = err
        # A zero-width CI gives NaN instead of inf.
        stacks["A_theta_SNR"] = stacks["A_theta"].abs() / err.replace(0, np.nan)
    else:
        stacks["A_theta_err"] = np.nan
        stacks["A_theta_SNR"] = np.nan


def _load_meta(meta_path):
    """Read the optional prestack meta CSV; return None (with a warning) if unavailable."""
    if not meta_path:
        return None
    if not os.path.exists(meta_path):
        print(f"[warn] Meta file {meta_path} not found; skipping meta merge.")
        return None
    try:
        return pd.read_csv(meta_path)
    except Exception as e:  # best-effort: any read/parse failure just skips the merge
        print(f"[warn] Could not read meta file {meta_path}: {e}")
        return None


def _merge_meta(stacks, meta, meta_path):
    """Left-merge one meta row per stack_id onto ``stacks`` and return the result."""
    if meta is None:
        return stacks
    if "stack_id" not in meta.columns:
        print(f"[warn] Meta file {meta_path} has no 'stack_id' column; skipping meta merge.")
        return stacks
    # One row per stack_id; if meta has per-annulus rows, keep the first.
    meta_stack = meta.drop_duplicates(subset=["stack_id"]).copy()
    try:
        return stacks.merge(
            meta_stack, on="stack_id", how="left", suffixes=("", "_meta")
        )
    except ValueError:
        # pandas refuses to merge incompatible key dtypes (e.g. int64 vs object).
        # Fall back to the stringified key that is also used for the JSON lookups.
        print(
            f"[warn] stack_id dtypes differ between plateau CSV and {meta_path}; "
            "merging on the string form of stack_id."
        )
        meta_stack["stack_id_str"] = meta_stack["stack_id"].astype(str)
        meta_stack = meta_stack.drop(columns=["stack_id"]).drop_duplicates(
            subset=["stack_id_str"]
        )
        return stacks.merge(
            meta_stack, on="stack_id_str", how="left", suffixes=("", "_meta")
        )


def main():
    """
    Build the per-stack health table and write it to ``--out-csv``.

    Reads the plateau CSV (required) and enriches it with:
      * RG_mid_bin, the midpoint parsed from ``R_G_bin``;
      * window bounds and width from windows.json;
      * flatness metrics from flatness.json, when not already in the CSV;
      * amplitude error / S/N from the A_theta confidence interval;
      * optional stack-level meta columns from ``--prestack-meta``.

    Raises:
        SystemExit: if the plateau CSV is missing or lacks a ``stack_id`` column.
    """
    args = build_parser().parse_args()

    # --- read plateau CSV (core per-stack info) ---
    if not os.path.exists(args.plateau_csv):
        raise SystemExit(
            f"ERROR: plateau CSV not found at {args.plateau_csv}.\n"
            f"Make sure your T3 run has produced lensing_plateau.csv first."
        )
    stacks = pd.read_csv(args.plateau_csv)

    if "stack_id" not in stacks.columns:
        raise SystemExit("ERROR: lensing_plateau.csv missing 'stack_id' column.")

    # JSON object keys are strings, so look stacks up by the stringified id.
    stacks["stack_id_str"] = stacks["stack_id"].astype(str)

    # --- parse size-bin mid (kpc-ish) ---
    if "R_G_bin" in stacks.columns:
        stacks["RG_mid_bin"] = stacks["R_G_bin"].apply(safe_parse_mid)
    else:
        stacks["RG_mid_bin"] = np.nan

    _attach_windows(stacks, load_json_or_empty(args.windows_json))
    _attach_flatness(stacks, args.flatness_json)
    _attach_snr(stacks)

    meta = _load_meta(args.prestack_meta)
    stacks = _merge_meta(stacks, meta, args.prestack_meta)

    # --- clean up helper column ---
    stacks = stacks.drop(columns=["stack_id_str"])

    # --- ensure outputs directory exists, then write ---
    out_dir = os.path.dirname(args.out_csv)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    stacks.to_csv(args.out_csv, index=False)
    print(f"[info] Wrote {len(stacks)} rows to {args.out_csv}")


if __name__ == "__main__":
    # Run the table builder only when executed as a script, not on import.
    main()
